Problem


As can be seen in 20_04_08_Ridge.html, there is a relation between the score the model assigns to a gene and the gene’s mean level of expression. This is a problem because we had previously discovered a bias in the SFARI scores related to mean level of expression (Preprocessing/Gandal/AllRegions/RMarkdowns/20_04_03_SFARI_genes.html), so mean expression could be a confounding factor in our model and the reason why it seems to perform well. We therefore need to remove this bias to recover the true biological signal mixed with it and improve the quality of our model.


Weighting Technique


General idea:


Train the model with equal weights for all samples

For each iteration:
  calculate the bias
  correct the weights to reduce the bias
  retrain the model

Return the last model
  


Parameters: learning rate eta, number of iterations T


Pseudocode:

lambda = 0
w = [1, ..., 1]
c = std(meanExpr(D))

h = train classifier H with lambda and w

for t in 1, ..., T do
  bias = <h(x), c(x)>
  lambda = lambda - eta * bias
  w_hat_i = exp(lambda * c(x_i))
  w_i = w_hat_i / (1 + w_hat_i) if y_i = 1, 1 / (1 + w_hat_i) if y_i = 0
  retrain h with the new weights

Return h
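As a sanity check, one iteration of this update can be sketched in R on toy values (all numbers below are made up purely to show the mechanics; they are not taken from the real dataset):

```r
# One iteration of the bias-correcting weight update on toy data
set.seed(1)
c_x    = rnorm(10)                      # standardised mean expression per gene
y      = rep(c(TRUE, FALSE), 5)         # labels (SFARI / not SFARI)
h_x    = rep(c(TRUE, FALSE), each = 5)  # current predictions of the classifier
eta    = 0.5
lambda = 0

bias   = mean(c_x[h_x])                 # <h(x), c(x)> over the positive predictions
lambda = lambda - eta*bias
w_hat  = exp(lambda*c_x)
w      = ifelse(y, w_hat/(1 + w_hat), 1/(1 + w_hat))  # then retrain h with weights w
```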
library(tidyverse)
library(knitr)
library(plotly) ; library(viridis) ; library(gridExtra) ; library(RColorBrewer) ; library(corrplot) ; library(reshape2)
library(biomaRt)
library(Rtsne)
library(caret) ; library(ROCR) ; library(car)
library(polycor)
library(expss)

SFARI_colour_hue = function(r) {
  c('#FF7631','#FFB100','#E8E328','#8CC83F','#62CCA6','#59B9C9','#b3b3b3','#808080','gray','#d9d9d9')[r]
}

Load data

# Clusterings
clustering_selected = 'DynamicHybridMergedSmall'
clusterings = read_csv('./../Data/clusters.csv')
clusterings$Module = clusterings[,clustering_selected] %>% data.frame %>% unlist %>% unname

assigned_module = data.frame('ID' = clusterings$ID, 'Module' = clusterings$Module)

# Original dataset
original_dataset = read.csv(paste0('./../Data/dataset_', clustering_selected, '.csv'), row.names=1)

# Model dataset
# load('./../Data/LR_model.RData')
# 
# Regression data
load(file='./../Data/Ridge_model_robust.RData')
test_set = predictions

# Mean Expression data
load('./../Data/preprocessed_data.RData')
datExpr = datExpr %>% data.frame
DE_info = DE_info %>% data.frame


# Add gene symbol
getinfo = c('ensembl_gene_id','external_gene_id')
mart = useMart(biomart='ENSEMBL_MART_ENSEMBL', dataset='hsapiens_gene_ensembl',
               host='feb2014.archive.ensembl.org')
gene_names = getBM(attributes=getinfo, filters=c('ensembl_gene_id'), values=rownames(dataset), mart=mart)

rm(dds, datGenes, datMeta, clustering_selected, clusterings, mart, getinfo, fit, train_set, negative_set, predictions)


Remove Bias


Demographic Parity


Using Demographic Parity as a measure of bias: A fair classifier h should make positive predictions on each segment \(G\) of the population at the same rate as in the population as a whole

This definition is for discrete segments of the population. Since our bias is present across the whole population, but to a different degree depending on the gene’s mean level of expression, we have to adapt this definition to a continuous bias scenario

Demographic Parity for our problem: A fair classifier h should make positive predictions on genes with a certain mean level of expression at the same rate as in all of the genes in the dataset


Demographic Parity bias metric


The original formula for the Demographic Parity bias is

  • \(c(x,0) = 0\) when the prediction is negative

  • \(c(x,1) = \frac{g(x)}{Z_G}-1\) when the prediction is positive, where \(g(x)\) is the Kronecker delta indicating whether the sample belongs to the protected group and \(Z_G\) is the proportion of the population that belongs to the group we want to protect against bias


Using these definitions in our problem:

\(g(x):\) Since all our samples belong to the protected group, this would always be 1

\(Z_G:\) Since all of our samples belong to the protected group, this would also always be 1

So our measure of bias \(c(x,1) = \frac{1}{1}-1 = 0\) for all samples. This doesn’t work, so we need to adapt it to our continuous case
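To see the degeneracy concretely (values hardcoded just for illustration):

```r
# When every sample belongs to the protected group, the original metric vanishes
g   = 1            # g(x) = 1 for all samples
Z_G = 1            # the whole population is in the protected group
c_1 = g/Z_G - 1    # c(x,1) = 0, so the bias <h(x), c(x)> is always 0
```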


Adaptation of the bias metric


We can use \(c(x,1) = std(meanExpr(x))\), the standardised mean expression, as the constraint function. This way, when we calculate the bias of the dataset:

\(h(x)\cdot c(x)\) will only be zero if the positive samples are balanced around the mean expression, and its sign will indicate the direction of the bias
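A quick check of this adapted metric on synthetic values (everything below is made up for illustration) shows the intended behaviour:

```r
# The bias is ~0 when positive predictions are balanced around the mean expression,
# and positive when they concentrate on highly expressed genes
meanExpr = seq(1, 10, length.out = 100)
c_x      = (meanExpr - mean(meanExpr))/sd(meanExpr)   # standardised mean expression

balanced = rep(c(TRUE, FALSE), 50)       # positives spread across the expression range
skewed   = meanExpr > median(meanExpr)   # positives only among highly expressed genes

mean(c_x[balanced])  # close to 0
mean(c_x[skewed])    # clearly positive: bias towards high expression
```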


Notes:

  • Running the model several times with different test/train partitions to obtain more robust results
### DEFINE FUNCTIONS

# Create positive training set including all SFARI scores
positive_sample_balancing_SFARI_scores = function(p, seed){
  
  set.seed(seed)
  positive_train_idx = c()
  
  for(score in 1:6){
    score_genes = rownames(original_dataset)[rownames(original_dataset) %in% rownames(dataset) & original_dataset$gene.score == score]
    score_idx = which(rownames(dataset) %in% score_genes)
    score_train_idx = sample(score_idx, size = ceiling(p*length(score_idx)))
    
    positive_train_idx = c(positive_train_idx, score_train_idx)
  }
  
  return(positive_train_idx)
}

create_train_test_sets = function(p, over_sampling_fold, seed){
  
  ### CREATE POSITIVE TRAINING SET (balancing SFARI scores and over-sampling)
  positive_train_idx = positive_sample_balancing_SFARI_scores(p, seed)
  add_obs = sample(positive_train_idx, size = ceiling(over_sampling_fold*length(positive_train_idx)), replace=TRUE)
  positive_train_idx = c(positive_train_idx, add_obs)
  
  
  ### CREATE NEGATIVE TRAINING SET
  negative_idx = which(!dataset$SFARI)
  negative_train_idx = sample(negative_idx, size = length(positive_train_idx))
  
  
  ### CREATE TRAIN AND TEST SET
  train_set = dataset[sort(c(positive_train_idx, negative_train_idx)),]
  test_set = dataset[-unique(c(positive_train_idx, negative_train_idx)),]
  
  return(list('train_set' = train_set, 'test_set' = test_set))
  
}

run_model = function(p, over_sampling_fold, seed){
  
  # Create train and test sets
  train_test_sets = create_train_test_sets(p, over_sampling_fold, seed)
  train_set = train_test_sets[['train_set']]
  test_set = train_test_sets[['test_set']]
  
  # Train model
  train_set$SFARI = train_set$SFARI %>% as.factor
  
  # Initial parameters
  set.seed(seed)
  eta = 0.5
  lambda = 0
  Loops = 50
  w = rep(1, nrow(train_set))
  h = train(SFARI ~., data = train_set, method = 'glmnet', trControl = trainControl('cv', number = 10),
                tuneGrid = expand.grid(alpha = 0, lambda = 10^seq(1, -3, by = -.1)))
  
  mean_expr = data.frame('ID' = rownames(datExpr), 'meanExpr' = rowMeans(datExpr)) %>%
              filter(ID %in% rownames(train_set)) %>% right_join(data.frame('ID' = substr(rownames(train_set),1,15)), by = 'ID') %>%
              mutate('meanExpr_std' = (meanExpr-mean(meanExpr))/sd(meanExpr))
  
  # Track behaviour of plot
  bias_vec = c()
  acc_vec = c()
  
  for(l in 1:Loops){
    
    # Calculate bias for positive predicted samples
    bias = mean(mean_expr$meanExpr_std[predict(h,train_set) %>% as.logical])
    
    # Update weights
    lambda = lambda - eta*bias
    w_hat = exp(lambda*mean_expr$meanExpr_std)
    w = 1/(1+w_hat)
    w[train_set$SFARI %>% as.logical] = w[train_set$SFARI %>% as.logical] * w_hat[train_set$SFARI %>% as.logical]
    
    # Update tracking vars
    bias_vec = c(bias_vec, bias)
    acc_vec = c(acc_vec, mean(predict(h,train_set) == train_set$SFARI))
    
    # Update h
    h = train(SFARI ~., data = train_set, method = 'glmnet', weights = w, trControl = trainControl('cv', number = 10),
                tuneGrid = expand.grid(alpha = 0, lambda = 10^seq(1, -3, by = -.1)))
  }
  
  # Predict labels in test set
  predictions = h %>% predict(test_set, type='prob')
  preds = data.frame('ID'=rownames(test_set), 'prob'=predictions$`TRUE`) %>% mutate(pred = prob>0.5)
  
  # Measure performance of the model
  acc = mean(test_set$SFARI==preds$pred)
  pred_ROCR = prediction(preds$prob, test_set$SFARI)
  AUC = performance(pred_ROCR, measure='auc')@y.values[[1]]
  
  return(list('preds' = preds, 'lambda' = lambda, 'bias_vec' = bias_vec, 'acc_vec' = acc_vec, 'acc' = acc, 'AUC' = AUC))
}


### RUN MODEL

# Parameters
p = 0.8
over_sampling_fold = 3
n_iter = 50
seeds = 123:(123+n_iter-1)

# Store outputs
acc = c()
AUC = c()
predictions = data.frame('ID' = rownames(dataset), 'SFARI' = dataset$SFARI, 'prob' = 0, 'pred' = 0, 'n' = 0)

for(seed in seeds){
  
  # Run model
  model_output = run_model(p, over_sampling_fold, seed)
  
  # Update outputs
  acc = c(acc, model_output[['acc']])
  AUC = c(AUC, model_output[['AUC']])
  preds = model_output[['preds']]
  update_preds = preds %>% dplyr::select(-ID) %>% mutate(n=1)
  predictions[predictions$ID %in% preds$ID, c('prob','pred','n')] = predictions[predictions$ID %in% preds$ID, c('prob','pred','n')] + update_preds
  
  if(seed == seeds[1]) {# Save the results from the bias correction iterations from one of the runs to analyse later
    lambda = model_output[['lambda']]
    bias_vec = model_output[['bias_vec']]
    acc_vec = model_output[['acc_vec']]
  }
}

predictions = predictions %>% mutate(prob = prob/n, pred_count = pred, pred = prob>0.5)


rm(p, over_sampling_fold, seeds, update_preds, positive_sample_balancing_SFARI_scores, create_train_test_sets, run_model)

The bias decreases until it oscillates around zero, and the accuracy is not affected much

plot_info = data.frame('iter' = 1:length(bias_vec), 'bias' = bias_vec, 'accuracy' = acc_vec) %>% melt(id.vars = 'iter')

plot_info %>% ggplot(aes(x=iter, y=value, color = variable)) + geom_line() + theme_minimal()

  • Since the bias increases the probability of being classified as 1 for genes with higher levels of expression, as the level of expression of a gene increases, the algorithm:

    • Increases the weight of genes with a negative label

    • Decreases the weight of genes with a positive label

mean_expr = data.frame('ID' = rownames(datExpr), 'meanExpr' = rowMeans(datExpr)) %>%
            left_join(predictions, by = 'ID') %>% filter(n>0) %>%
            mutate('meanExpr_std' = (meanExpr-mean(meanExpr))/sd(meanExpr))

w_hat = exp(lambda*mean_expr$meanExpr_std) # inverse to mean expr
w0 = 1/(1+w_hat) # proportional to mean expr
w = 1/(1+w_hat)
w[mean_expr$SFARI %>% as.logical] = w[mean_expr$SFARI %>% as.logical]*w_hat[mean_expr$SFARI %>% as.logical] # inverse to mean expr for Positives, proportional for Negatives
plot_data = data.frame(meanExpr = mean_expr$meanExpr, w_hat = w_hat, w0 = w0, w = w, SFARI = mean_expr$SFARI, pred = mean_expr$pred)

plot_data %>% ggplot(aes(meanExpr, w, color = SFARI)) + geom_point(alpha = 0.3) + ylab('weight') + xlab('Mean Expression') + 
              ggtitle('Weights of the final model') + ylim(c(0,1)) + theme_minimal()
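The S-shapes in the plot follow directly from the update rule: since \(\hat{w}_i = e^{\lambda \, c(x_i)}\), the weights are logistic functions of the standardised mean expression,

\(w_i = \frac{\hat{w}_i}{1+\hat{w}_i} = \sigma(\lambda \, c(x_i))\) for positive genes and \(w_i = \frac{1}{1+\hat{w}_i} = \sigma(-\lambda \, c(x_i))\) for negative genes,

so, with the negative \(\lambda\) the model converges to, highly expressed positive genes are down-weighted and highly expressed negative genes are up-weighted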


Results


The relation is not completely gone; there seems to be a negative relation for the genes with the lowest levels of expression.

Even though the trend line is not as flat as with the first method, this is expected: we are no longer fixing the scores directly as we were before, the correction is now just a consequence of the weight adjustments made inside the model, so it makes sense for it to be less exact

test_set_backup = test_set

test_set = test_set %>% left_join(predictions %>% mutate(corrected_score = prob, corrected_pred = pred) %>% 
                                  dplyr::select(ID, corrected_score, corrected_pred), by = 'ID')

# # Correct Bias
# predictions = h %>% predict(test_set, type='prob')
# test_set$corrected_score = predictions$`TRUE`
# test_set$corrected_pred = test_set$corrected_score>0.5

# Plot results
plot_data = data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)) %>% 
            right_join(test_set, by='ID')

plot_data %>% ggplot(aes(meanExpr, corrected_score)) + geom_point(alpha=0.1, color='#0099cc') +
              geom_smooth(method='gam', color='gray', alpha=0.2) + ylab('Corrected Score') + xlab('Mean Expression') +
              theme_minimal() + ggtitle('Mean expression vs Model score corrected using adjusted weights')




Performance Metrics


Confusion matrix

conf_mat = test_set %>% apply_labels(SFARI = 'Actual Labels', 
                                     corrected_score = 'Corrected Score', 
                                     corrected_pred = 'Corrected Label Prediction')

cro(conf_mat$SFARI, list(conf_mat$corrected_pred, total()))
                             Corrected Label Prediction        #Total
                             FALSE          TRUE
 Actual Labels
   FALSE                      8746          6391                15137
   TRUE                        357           540                  897
   #Total cases               9103          6931                16034
rm(conf_mat)

Accuracy


The accuracy was expected to decrease because the bias was helping classify samples correctly, but for the wrong reasons

cat(paste0('ACCURACY: mean = ', round(mean(acc),4), ' SD = ', round(sd(acc),4)))
## ACCURACY: mean = 0.5762 SD = 0.0076
old_acc = mean(test_set$SFARI==(test_set$prob>0.5))
acc = mean(acc)
cat(paste0('Accuracy decreased ',round(old_acc-acc,4), ' points'))
## Accuracy decreased 0.0642 points
rm(old_acc)

ROC Curve


cat(paste0('AUC:      mean = ', mean(AUC), ' SD = ', sd(AUC)))
## AUC:      mean = 0.622004444364543 SD = 0.0209978310442912
pred_ROCR = prediction(test_set$corrected_score, test_set$SFARI)

roc_ROCR = performance(pred_ROCR, measure='tpr', x.measure='fpr')
auc = performance(pred_ROCR, measure='auc')@y.values[[1]]

plot(roc_ROCR, main=paste0('ROC curve (AUC=',round(auc,2),')'), col='#009999')
abline(a=0, b=1, col='#666666')

rm(roc_ROCR, auc)

Lift Curve


lift_ROCR = performance(pred_ROCR, measure='lift', x.measure='rpp')
plot(lift_ROCR, main='Lift curve', col='#86b300')

rm(lift_ROCR, pred_ROCR)

Analyse Model

Looks very similar to before; the means of each group are a bit closer together

plot_data = test_set %>% filter(!is.na(corrected_score)) %>% dplyr::select(corrected_score, SFARI)

ggplotly(plot_data %>% ggplot(aes(corrected_score, fill=SFARI, color=SFARI)) + geom_density(alpha=0.3) + xlab('Score') +
         geom_vline(xintercept = mean(plot_data$corrected_score[plot_data$SFARI]), color = '#00C0C2', linetype = 'dashed') +
         geom_vline(xintercept = mean(plot_data$corrected_score[!plot_data$SFARI]), color = '#FF7371', linetype = 'dashed') +
         theme_minimal() + ggtitle('Model score distribution by SFARI Label'))

The positive relation between SFARI scores and Model scores is still there but is not as strong as before

plot_data = test_set %>% dplyr::select(ID, corrected_score) %>%
            left_join(original_dataset %>% mutate(ID=rownames(original_dataset)), by='ID') %>%
            dplyr::select(ID, corrected_score, gene.score) %>% apply_labels(gene.score='SFARI Gene score')

cro(plot_data$gene.score)
 SFARI Gene score    #Total
   1                     25
   2                     64
   3                    191
   4                    432
   5                    163
   6                     22
   None                15137
   #Total cases        16034
ggplotly(plot_data %>% ggplot(aes(gene.score, corrected_score, fill=gene.score)) + geom_boxplot() + 
              scale_fill_manual(values=SFARI_colour_hue(r=c(1:6,8,7))) + 
              ggtitle('Distribution of the Model scores by SFARI score') +
              xlab('SFARI score') + ylab('Model score') + theme_minimal())

Print genes with highest corrected scores in test set

test_set %>% dplyr::select(ID, corrected_score, SFARI) %>% arrange(desc(corrected_score)) %>% top_n(50, wt=corrected_score) %>%
             left_join(original_dataset %>% mutate(ID=rownames(original_dataset)), by='ID')  %>% 
             left_join(gene_names, by = c('ID'='ensembl_gene_id')) %>%
             dplyr::rename('GeneSymbol' = external_gene_id, 'Probability' = corrected_score,
                           'ModuleDiagnosis_corr' = MTcor, 'GeneSignificance' = GS) %>%
             mutate(ModuleDiagnosis_corr = round(ModuleDiagnosis_corr,4), Probability = round(Probability,4), 
                    GeneSignificance = round(GeneSignificance,4)) %>%
             dplyr::select(GeneSymbol, GeneSignificance, ModuleDiagnosis_corr, Module, Probability, gene.score) %>%
             kable(caption = 'Genes with highest model probabilities from the test set')
Genes with highest model probabilities from the test set
GeneSymbol GeneSignificance ModuleDiagnosis_corr Module Probability gene.score
HIVEP2 0.0165 -0.9514 #00C0AF 0.7978 None
SNX25 0.0951 -0.0094 #00A7FF 0.7919 None
AHI1 0.0143 -0.0094 #00A7FF 0.7818 None
CNTNAP5 -0.0796 -0.9514 #00C0AF 0.7782 4
ARHGAP20 -0.0230 -0.9514 #00C0AF 0.7734 None
SORCS3 0.0790 -0.0094 #00A7FF 0.7690 None
HECTD2 0.1862 -0.0094 #00A7FF 0.7675 None
CLMP 0.0344 -0.0094 #00A7FF 0.7674 None
ADAMTSL1 0.0663 -0.9514 #00C0AF 0.7671 None
PLXNC1 -0.0088 -0.0094 #00A7FF 0.7652 None
CELF2 -0.0605 -0.9514 #00C0AF 0.7640 None
KIAA1217 -0.1868 -0.9514 #00C0AF 0.7630 None
ITGAM 0.2983 0.7287 #39B600 0.7627 None
KCNJ6 -0.1379 -0.9514 #00C0AF 0.7625 None
PLD4 0.1010 0.7287 #39B600 0.7621 None
MOG -0.0290 -0.9514 #00C0AF 0.7614 None
CDH18 0.1860 -0.9514 #00C0AF 0.7585 None
PACSIN3 -0.0825 -0.9514 #00C0AF 0.7585 None
ITGAX -0.0073 -0.0094 #00A7FF 0.7569 None
DTX4 -0.0102 -0.6031 #00BA38 0.7561 None
CSMD3 0.0768 0.2525 #C99800 0.7556 None
FOLH1 -0.0506 -0.0094 #00A7FF 0.7548 5
EGR1 0.2192 -0.0094 #00A7FF 0.7546 None
CSMD1 -0.2928 -0.9514 #00C0AF 0.7543 4
DOCK2 0.4756 0.7287 #39B600 0.7538 None
SIK2 0.1631 0.1127 #FF62BC 0.7529 None
FMNL1 -0.2223 -0.6031 #00BA38 0.7521 None
SERPING1 0.2991 0.8742 #E58700 0.7516 None
EVL -0.1527 -0.9514 #00C0AF 0.7514 None
TBL1X 0.3339 0.7287 #39B600 0.7498 4
SCN3B 0.0668 0.1127 #FF62BC 0.7485 None
POC1B 0.1540 -0.0094 #00A7FF 0.7483 None
LRRC7 0.1862 -0.9514 #00C0AF 0.7479 None
ARC 0.2517 -0.0094 #00A7FF 0.7478 None
RAPH1 -0.0534 -0.6031 #00BA38 0.7473 None
ASAP1 -0.0896 -0.6031 #00BA38 0.7472 None
C3 0.3945 0.7287 #39B600 0.7471 None
PARVG 0.5666 0.7287 #39B600 0.7470 None
AQP1 -0.0924 -0.9514 #00C0AF 0.7460 None
NCKAP1L 0.4157 0.7287 #39B600 0.7459 None
MIDN -0.0520 -0.6031 #00BA38 0.7458 None
MBD5 0.1943 0.0586 #FE6E8A 0.7456 3
DAAM1 0.1579 -0.0094 #00A7FF 0.7456 None
LEPREL1 0.2452 0.7916 #00C097 0.7450 None
PTPRO 0.1187 -0.0094 #00A7FF 0.7447 None
TBC1D8 0.0969 0.7287 #39B600 0.7444 None
ETV6 -0.1418 -0.6031 #00BA38 0.7444 None
GARNL3 -0.0933 -0.5355 #FD61D1 0.7425 None
RIMBP2 0.2615 -0.0094 #00A7FF 0.7419 None
VPS13B 0.3439 0.7916 #00C097 0.7418 None



Negative samples distribution


  • There is a lot of noise but, in general, genes with the lowest scores had their score increased and genes with the highest scores had it decreased
negative_set = test_set %>% filter(!SFARI)

negative_set %>% mutate(diff = abs(prob-corrected_score)) %>% 
             ggplot(aes(prob, corrected_score, color = diff)) + geom_point(alpha=0.2) + scale_color_viridis() + 
             geom_abline(slope=1, intercept=0, color='gray', linetype='dashed') + 
             geom_smooth(color='#666666', alpha=0.5, se=TRUE, size=0.5) + coord_fixed() +
             xlab('Original probability') + ylab('Corrected probability') + theme_minimal() + theme(legend.position = 'none')

negative_set_table = negative_set %>% filter(!is.na(corrected_pred)) %>%
                     apply_labels(corrected_score = 'Corrected Probability', 
                                  corrected_pred = 'Corrected Class Prediction',
                                  pred = 'Original Class Prediction')

cro(negative_set_table$pred, list(negative_set_table$corrected_pred, total()))
                              Corrected Class Prediction        #Total
                              FALSE          TRUE
 Original Class Prediction
   FALSE                       7320          2410                 9730
   TRUE                        1426          3981                 5407
   #Total cases                8746          6391                15137
cat(paste0('\n', round(100*mean(negative_set_table$corrected_pred == negative_set_table$pred)),
           '% of the genes maintained their original predicted class'))
## 
## 75% of the genes maintained their original predicted class
rm(negative_set_table)

Probability and Gene Significance


*The transparent version of the trend line is the original trend line

The relation is the opposite of before: the higher the Gene Significance, the higher the score, with the lowest scores corresponding to under-expressed genes

negative_set %>% ggplot(aes(corrected_score, GS, color=MTcor)) + geom_point() + geom_smooth(method='gam', color='#666666') +
                 geom_line(stat='smooth', method='gam', color='#666666', alpha=0.5, size=1.2, aes(x=prob)) +
                 geom_hline(yintercept=mean(negative_set$GS), color='gray', linetype='dashed') +
                 scale_color_gradientn(colours=c('#F8766D','white','#00BFC4')) + xlab('Corrected Score') +
                 ggtitle('Relation between the Model\'s Corrected Score and Gene Significance') + theme_minimal()

Summarised version of score vs mean expression, plotting by module instead of by gene

The difference in the trend lines between this plot and the one above is that the one above takes all the genes into consideration, while this one treats each module as a single observation, so the top one is strongly affected by big modules while the bottom one treats all modules the same

The transparent version of each point and trend lines are the original values and trends before the bias correction

  • Similar conclusions as above
plot_data = negative_set %>% filter(!is.na(corrected_pred)) %>% group_by(MTcor) %>% 
            summarise(mean = mean(prob), sd = sd(prob), new_mean = mean(corrected_score), new_sd = sd(corrected_score), n = n()) %>%
            mutate(MTcor_sign = ifelse(MTcor>0, 'Positive', 'Negative')) %>% left_join(original_dataset, by='MTcor') %>%
            dplyr::select(Module, MTcor, MTcor_sign, mean, new_mean, sd, new_sd, n) %>% distinct()
colnames(plot_data)[1] = 'ID'

ggplotly(plot_data %>% ggplot(aes(MTcor, new_mean, size=n, color=MTcor_sign)) + geom_point(aes(id = ID)) + 
         geom_smooth(method='loess', color='gray', se=FALSE) + geom_smooth(method='lm', se=FALSE) + 
         geom_point(aes(y=mean), alpha=0.3) + xlab('Module-Diagnosis correlation') + ylab('Mean Corrected Score by the Model') + 
         geom_line(stat='smooth', method='loess', color='gray', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) + 
         geom_line(stat='smooth', method='lm', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) + 
         theme_minimal() + theme(legend.position='none'))


Probability and mean level of expression


Check if correcting by gene also corrected by module: yes, the bias seems to be removed completely; it may even be slightly negative now

mean_and_sd = data.frame(ID=rownames(datExpr), meanExpr=rowMeans(datExpr), sdExpr=apply(datExpr,1,sd))

plot_data = negative_set %>% filter(!is.na(corrected_pred)) %>% left_join(mean_and_sd, by='ID') %>% 
            left_join(original_dataset %>% mutate(ID=rownames(original_dataset)) %>% 
                      dplyr::select(ID, Module), by='ID')

plot_data2 = plot_data %>% group_by(Module) %>% summarise(meanExpr = mean(meanExpr), meanProb = mean(prob), 
                                                          new_meanProb = mean(corrected_score), n=n())

ggplotly(plot_data2 %>% ggplot(aes(meanExpr, new_meanProb, size=n)) + 
         geom_point(color=plot_data2$Module) + geom_point(color=plot_data2$Module, alpha=0.3, aes(y=meanProb)) + 
         geom_smooth(method='loess', se=TRUE, color='gray', alpha=0.1, size=0.7) + 
         geom_line(stat='smooth', method='loess', se=TRUE, color='gray', alpha=0.4, size=1.2, aes(y=meanProb)) +
         theme_minimal() + theme(legend.position='none') + xlab('Mean Expression') + ylab('Corrected Probability') +
         ggtitle('Mean expression vs corrected Model score by Module'))
rm(plot_data2, mean_and_sd)


Probability and SD of level of expression


The relation between SD and score became stronger than before

plot_data %>% ggplot(aes(sdExpr, corrected_score)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='lm', color='#999999', se=FALSE, alpha=1) + xlab('SD') + ylab('Corrected Probability') +
              geom_line(stat='smooth', method='lm', color='#999999', se=FALSE, alpha=0.4, size=1.5, aes(y=prob)) + 
              theme_minimal() + ggtitle('SD vs model probability by gene') + scale_x_sqrt()


Probability and lfc


For under-expressed genes, the relation between LFC and probability got inverted. The difference is quite big

For over-expressed genes, the trend didn’t change, it just got translated higher. Now, in general, over-expressed genes have higher probabilities than under-expressed genes

plot_data = negative_set %>% left_join(DE_info %>% mutate(ID=rownames(DE_info)), by='ID')

plot_data %>% ggplot(aes(log2FoldChange, corrected_score)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='loess', color='gray', alpha=0.1) + xlab('LFC') + ylab('Corrected Probability') +
              geom_line(stat='smooth', method='loess', color='gray', alpha=0.4, size=1.5, aes(y=prob)) +
              theme_minimal() + ggtitle('LFC vs model probability by gene')

The trends for DE genes don’t seem to have changed much, they were just translated, upwards for over-expressed genes, and downwards for under-expressed genes

The big difference we see in the plot above is just the connection between the patterns for under and over-expressed genes

The only group that doesn’t seem to be affected much is the non-DE group of under-expressed genes

p1 = plot_data %>% filter(log2FoldChange<0) %>% mutate(DE = padj<0.05) %>% ggplot(aes(log2FoldChange, corrected_score, color=DE)) + geom_point(alpha=0.1) + 
                   geom_smooth(method='loess', alpha=0.1) + xlab('') + ylab('Corrected Probability') + 
                   ylim(c(min(plot_data$corrected_score), max(plot_data$corrected_score))) + 
                   geom_line(stat='smooth', method='loess', alpha=0.4, size=1.5, aes(y=prob, color = DE)) +
                   theme_minimal() + theme(legend.position = 'none', plot.margin=unit(c(1,-0.3,1,1), 'cm'))

p2 = plot_data %>% filter(log2FoldChange>=0) %>% mutate(DE = padj<0.05) %>% ggplot(aes(log2FoldChange, corrected_score, color=DE)) + geom_point(alpha=0.1) + 
                   geom_smooth(method='loess', alpha=0.1) + xlab('') + ylab('Corrected Probability') + ylab('') +
                   scale_y_continuous(position = 'right', limits = c(min(plot_data$corrected_score), max(plot_data$corrected_score))) +
                   geom_line(stat='smooth', method = 'loess', alpha=0.4, size=1.5, aes(y = prob, color = DE)) +
                   theme_minimal() + theme(plot.margin = unit(c(1,1,1,-0.3), 'cm'), axis.ticks.y = element_blank())

grid.arrange(p1, p2, nrow=1, top = 'LFC vs model probability by gene', bottom = 'LFC')

rm(p1, p2)


Probability and Module-Diagnosis correlation


The scores decreased for modules with negative correlation and increased for modules with positive correlation

module_score = negative_set %>% left_join(original_dataset %>% mutate(ID = rownames(original_dataset)), by='ID') %>%
               dplyr::select(ID, prob, corrected_score, Module, MTcor.x) %>% rename(MTcor = MTcor.x) %>% 
               left_join(data.frame(MTcor=unique(dataset$MTcor)) %>% arrange(by=MTcor) %>% 
                         mutate(order=1:length(unique(dataset$MTcor))), by='MTcor')

ggplotly(module_score %>% ggplot(aes(MTcor, corrected_score)) + geom_point(color=module_score$Module, aes(id=ID, alpha=corrected_score^4)) +
         geom_hline(yintercept=mean(module_score$corrected_score), color='gray', linetype='dotted') + 
         geom_line(stat='smooth', method = 'loess', color='gray', alpha=0.5, size=1.5, aes(x=MTcor, y=prob)) +
         geom_smooth(color='gray', method = 'loess', se = FALSE, alpha=0.3) + theme_minimal() + 
         xlab('Module-Diagnosis correlation') + ylab('Corrected Score'))



Conclusion


This bias correction makes bigger changes in the distribution of the probabilities than the post-processing one. Its main effect seems to be to reduce the importance of the under-expressed genes and increase the importance of over-expressed genes


Saving results

write.csv(test_set, file='./../Data/BC_weighting_approach_robust.csv', row.names = TRUE)




Session info

sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 18.04.4 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
## 
## locale:
##  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
##  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
##  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] expss_0.10.2       polycor_0.7-10     car_3.0-7          carData_3.0-3     
##  [5] ROCR_1.0-7         gplots_3.0.3       caret_6.0-86       lattice_0.20-41   
##  [9] Rtsne_0.15         biomaRt_2.40.5     reshape2_1.4.3     corrplot_0.84     
## [13] RColorBrewer_1.1-2 gridExtra_2.3      viridis_0.5.1      viridisLite_0.3.0 
## [17] plotly_4.9.2       knitr_1.28         forcats_0.5.0      stringr_1.4.0     
## [21] dplyr_0.8.5        purrr_0.3.3        readr_1.3.1        tidyr_1.0.2       
## [25] tibble_3.0.0       ggplot2_3.3.0      tidyverse_1.3.0   
## 
## loaded via a namespace (and not attached):
##   [1] readxl_1.3.1                backports_1.1.5            
##   [3] Hmisc_4.4-0                 plyr_1.8.6                 
##   [5] lazyeval_0.2.2              splines_3.6.3              
##   [7] crosstalk_1.1.0.1           BiocParallel_1.18.1        
##   [9] GenomeInfoDb_1.20.0         digest_0.6.25              
##  [11] foreach_1.5.0               htmltools_0.4.0            
##  [13] gdata_2.18.0                fansi_0.4.1                
##  [15] magrittr_1.5                checkmate_2.0.0            
##  [17] memoise_1.1.0               cluster_2.1.0              
##  [19] openxlsx_4.1.4              annotate_1.62.0            
##  [21] recipes_0.1.10              modelr_0.1.6               
##  [23] gower_0.2.1                 matrixStats_0.56.0         
##  [25] prettyunits_1.1.1           jpeg_0.1-8.1               
##  [27] colorspace_1.4-1            blob_1.2.1                 
##  [29] rvest_0.3.5                 haven_2.2.0                
##  [31] xfun_0.12                   crayon_1.3.4               
##  [33] RCurl_1.98-1.1              jsonlite_1.6.1             
##  [35] genefilter_1.66.0           survival_3.1-11            
##  [37] iterators_1.0.12            glue_1.3.2                 
##  [39] gtable_0.3.0                ipred_0.9-9                
##  [41] zlibbioc_1.30.0             XVector_0.24.0             
##  [43] DelayedArray_0.10.0         shape_1.4.4                
##  [45] BiocGenerics_0.30.0         abind_1.4-5                
##  [47] scales_1.1.0                DBI_1.1.0                  
##  [49] Rcpp_1.0.4                  xtable_1.8-4               
##  [51] progress_1.2.2              htmlTable_1.13.3           
##  [53] foreign_0.8-75              bit_1.1-15.2               
##  [55] Formula_1.2-3               stats4_3.6.3               
##  [57] lava_1.6.7                  prodlim_2019.11.13         
##  [59] glmnet_3.0-2                htmlwidgets_1.5.1          
##  [61] httr_1.4.1                  acepack_1.4.1              
##  [63] ellipsis_0.3.0              farver_2.0.3               
##  [65] pkgconfig_2.0.3             XML_3.99-0.3               
##  [67] nnet_7.3-13                 dbplyr_1.4.2               
##  [69] locfit_1.5-9.4              labeling_0.3               
##  [71] tidyselect_1.0.0            rlang_0.4.5                
##  [73] AnnotationDbi_1.46.1        munsell_0.5.0              
##  [75] cellranger_1.1.0            tools_3.6.3                
##  [77] cli_2.0.2                   generics_0.0.2             
##  [79] RSQLite_2.2.0               broom_0.5.5                
##  [81] evaluate_0.14               yaml_2.2.1                 
##  [83] ModelMetrics_1.2.2.2        bit64_0.9-7                
##  [85] fs_1.4.0                    zip_2.0.4                  
##  [87] caTools_1.18.0              nlme_3.1-147               
##  [89] xml2_1.2.5                  compiler_3.6.3             
##  [91] rstudioapi_0.11             png_0.1-7                  
##  [93] curl_4.3                    e1071_1.7-3                
##  [95] reprex_0.3.0                geneplotter_1.62.0         
##  [97] stringi_1.4.6               highr_0.8                  
##  [99] Matrix_1.2-18               vctrs_0.2.4                
## [101] pillar_1.4.3                lifecycle_0.2.0            
## [103] data.table_1.12.8           bitops_1.0-6               
## [105] GenomicRanges_1.36.1        latticeExtra_0.6-29        
## [107] R6_2.4.1                    KernSmooth_2.23-16         
## [109] rio_0.5.16                  IRanges_2.18.3             
## [111] codetools_0.2-16            MASS_7.3-51.5              
## [113] gtools_3.8.2                assertthat_0.2.1           
## [115] SummarizedExperiment_1.14.1 DESeq2_1.24.0              
## [117] withr_2.1.2                 S4Vectors_0.22.1           
## [119] GenomeInfoDbData_1.2.1      mgcv_1.8-31                
## [121] parallel_3.6.3              hms_0.5.3                  
## [123] grid_3.6.3                  rpart_4.1-15               
## [125] timeDate_3043.102           class_7.3-16               
## [127] rmarkdown_2.1               pROC_1.16.2                
## [129] base64enc_0.1-3             Biobase_2.44.0             
## [131] lubridate_1.7.4